df <- readRDS("../data/models/social-risk-crash-rate-data.rds")
YEARI will treat year as a factor and one-hot encode it.
df$year <- as.factor(df$year)
year_dummies <- model.matrix(~ year - 1, data = df)
df <- cbind(df[ , !(names(df) %in% c("year"))], year_dummies)
We will remove all possible target variables and keep only one per model training.
# Choose your target variable (e.g., crash rate per 1,000 residents)
target_var <- "crash_rate_per_1000"
# Remove all target variables except selected
cols_to_remove <- grep("per_1000",
names(df),
value = TRUE)
cols_to_remove <- setdiff(cols_to_remove, target_var) # keep this column
df <- df %>% select(-all_of(cols_to_remove),)
# Create feature matrix and target vector
X <- df %>% select(-target_var, -borough, -total_population, -geoid)
y <- df[[target_var]]
glimpse(X)
Rows: 13,518
Columns: 49
$ pct_male_population <dbl> 12.460865, 11.694881, 12.229106…
$ pct_female_population <dbl> 10.846132, 11.629654, 11.054169…
$ pct_white_population <dbl> 4.1225202, 3.6815086, 2.5856754…
$ pct_black_population <dbl> 2.435429, 2.630197, 2.833673, 2…
$ pct_asian_population <dbl> 0.19803330, 0.14388387, 0.23376…
$ pct_hispanic_population <dbl> 6.411328, 6.607147, 5.982424, 6…
$ pct_foreign_born <dbl> 2.909566, 2.442189, 2.187254, 2…
$ pct_age_under_18 <dbl> 2.130762, 2.643626, 2.333613, 2…
$ pct_age_18_34 <dbl> 1.715654, 1.653705, 1.882339, 1…
$ pct_age_35_64 <dbl> 3.296112, 3.079115, 3.041014, 3…
$ pct_age_65_plus <dbl> 1.80895804, 1.78607845, 1.56929…
$ median_income <dbl> 58582.658, 49964.513, 68000.000…
$ pct_income_under_25k <dbl> 0.5445916, 0.5544325, 0.5325841…
$ pct_income_25k_75k <dbl> 1.0568123, 1.1376418, 0.9533662…
$ pct_income_75k_plus <dbl> 0.9273290, 0.8824877, 1.2379531…
$ pct_below_poverty <dbl> 1.9574830, 1.9510653, 1.8091597…
$ median_gross_rent <dbl> 1579.1133, 1524.3577, 1701.0000…
$ pct_owner_occupied <dbl> 1.33318303, 1.35699756, 1.58439…
$ pct_renter_occupied <dbl> 1.2085934, 1.2305140, 1.1518859…
$ pct_no_vehicle <dbl> 0.5122207, 0.6522735, 0.6118619…
$ pct_less_than_hs <dbl> 1.4509748, 1.1951954, 1.1302166…
$ pct_hs_diploma <dbl> 1.5576081, 1.4196542, 1.2704773…
$ pct_some_college <dbl> 0.8473540, 0.8997538, 0.7256966…
$ pct_associates_degree <dbl> 0.4912749, 0.3280552, 0.3923234…
$ pct_bachelors_degree <dbl> 0.9482748, 0.9457966, 0.8537607…
$ pct_graduate_degree <dbl> 0.3998749, 0.7098271, 1.1037907…
$ pct_in_labor_force <dbl> 3.566504, 3.430191, 3.555304, 3…
$ pct_not_in_labor_force <dbl> 3.448445, 3.326595, 3.162980, 3…
$ unemployment_rate <dbl> 15.750133, 13.478747, 10.806175…
$ pct_commute_short <dbl> 0.7350082, 0.4604284, 0.1585556…
$ pct_commute_medium <dbl> 1.7061331, 1.6882374, 1.2521824…
$ pct_commute_long <dbl> 2.477320, 2.685832, 3.553271, 3…
$ pct_carpool <dbl> 0.35798327, 0.31462606, 0.00000…
$ pct_public_transit <dbl> 2.056500, 2.540030, 2.898721, 2…
$ pct_walk <dbl> 0.37702494, 0.30695226, 0.13009…
$ pct_bike <dbl> 0.00000000, 0.00000000, 0.00000…
$ pct_work_from_home <dbl> 0.01713750, 0.04604284, 0.04675…
$ pct_vehicle <dbl> 2.0165122, 1.9222885, 2.1120415…
$ post_pandemic <int> 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1…
$ poverty_vehicle_interaction <dbl> 3.9472883, 3.7505104, 3.8210202…
$ unemployment_vehicle_interaction <dbl> 31.760336, 25.910041, 22.823090…
$ workers_long_commute_interaction <dbl> 8.835372, 9.212919, 12.632957, …
$ homeowners_long_commute_interaction <dbl> 3.3027216, 3.6446678, 5.6297917…
$ year2018 <dbl> 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0…
$ year2019 <dbl> 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0…
$ year2020 <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0…
$ year2021 <dbl> 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0…
$ year2022 <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1…
$ year2023 <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0…
What This Does - Optuna (Python) handles the search
space and Bayesian optimization. - The final best parameters are applied
to fit the final_model. - Search space:
Instead of predefined grids, trial$suggest_float() and
trial$suggest_int() explore a range of values. -
Best parameters: study$best_params holds
the optimal hyperparameters.
## CONVERT TO DMATRIX
dtrain_all <- xgb.DMatrix(data = as.matrix(X), label = y)
## Start Python venv
reticulate::use_virtualenv("r-reticulate", required = TRUE)
## OPTUNA-BASED SPATIAL CV
optuna <- import("optuna")
boroughs <- unique(df$borough)
folds <- lapply(boroughs, function(b) which(df$borough != b))
# Optuna objective
objective <- function(trial) {
params <- list(
booster = "gbtree",
eta = trial$suggest_float("eta", 0.01, 0.3, log = TRUE),
max_depth = trial$suggest_int("max_depth", 3, 12),
min_child_weight = trial$suggest_int("min_child_weight", 1, 10),
subsample = trial$suggest_float("subsample", 0.5, 1.0),
colsample_bytree = trial$suggest_float("colsample_bytree", 0.5, 1.0),
gamma = trial$suggest_float("gamma", 0, 10),
lambda = trial$suggest_float("lambda", 0, 10),
alpha = trial$suggest_float("alpha", 0, 10)
)
rmse_scores <- numeric(length(folds))
for (i in seq_along(folds)) {
train_idx <- folds[[i]]
valid_idx <- setdiff(seq_len(nrow(dtrain_all)), train_idx)
dtrain <- xgb.DMatrix(data = as.matrix(X[train_idx, ]), label = y[train_idx])
dvalid <- xgb.DMatrix(data = as.matrix(X[valid_idx, ]), label = y[valid_idx])
model <- xgb.train(
params = params,
data = dtrain,
nrounds = 500,
watchlist = list(val = dvalid),
early_stopping_rounds = 20,
verbose = 0
)
rmse_scores[i] <- min(model$evaluation_log$val_rmse)
}
preds <- predict(model, as.matrix(X[valid_idx, ]))
return(Metrics::rmse(y[valid_idx], preds))
}
# Run Optuna study
set.seed(2025)
study <- optuna$create_study(direction = "minimize")
study$optimize(objective, n_trials = 50)
best_params <- study$best_params
print(best_params)
$eta
[1] 0.04865258
$max_depth
[1] 7
$min_child_weight
[1] 3
$subsample
[1] 0.9616739
$colsample_bytree
[1] 0.8674877
$gamma
[1] 9.91113
$lambda
[1] 1.266656
$alpha
[1] 2.7638
# Set seed
set.seed(2025)
# Split by index
train_index <- createDataPartition(y, p = 0.8, list = FALSE)
X_train <- X[train_index, ]
y_train <- y[train_index]
X_test <- X[-train_index, ]
y_test <- y[-train_index]
# Convert to xgb.DMatrix
dtrain <- xgb.DMatrix(data = as.matrix(X_train), label = y_train)
dtest <- xgb.DMatrix(data = as.matrix(X_test), label = y_test)
# Set seed
set.seed(2025)
# Training with parallel processing
final_model <- xgb.train(
params = list(
eta = best_params$eta,
max_depth = best_params$max_depth,
gamma = best_params$gamma,
colsample_bytree = best_params$colsample_bytree,
min_child_weight = best_params$min_child_weight,
subsample = best_params$subsample,
objective = "reg:squarederror",
eval_metric = "rmse"
),
data = dtrain,
nrounds = 1000,
watchlist = list(train = dtrain, test = dtest),
early_stopping_rounds = 20,
verbose = 1,
nthread = detectCores() - 1
)
[1] train-rmse:2.007458 test-rmse:1.993689
Multiple eval metrics are present. Will use test_rmse for early stopping.
Will train until test_rmse hasn't improved in 20 rounds.
[2] train-rmse:1.962899 test-rmse:1.959074
[3] train-rmse:1.921029 test-rmse:1.922708
[4] train-rmse:1.886514 test-rmse:1.899632
[5] train-rmse:1.854028 test-rmse:1.872285
[6] train-rmse:1.817282 test-rmse:1.843062
[7] train-rmse:1.780462 test-rmse:1.821862
[8] train-rmse:1.750987 test-rmse:1.806139
[9] train-rmse:1.717576 test-rmse:1.786795
[10] train-rmse:1.689064 test-rmse:1.765718
[11] train-rmse:1.663708 test-rmse:1.746783
[12] train-rmse:1.635800 test-rmse:1.732048
[13] train-rmse:1.610693 test-rmse:1.719750
[14] train-rmse:1.588118 test-rmse:1.707063
[15] train-rmse:1.565416 test-rmse:1.695671
[16] train-rmse:1.546982 test-rmse:1.684235
[17] train-rmse:1.526407 test-rmse:1.671553
[18] train-rmse:1.507937 test-rmse:1.662180
[19] train-rmse:1.491254 test-rmse:1.654204
[20] train-rmse:1.474855 test-rmse:1.646850
[21] train-rmse:1.456319 test-rmse:1.639243
[22] train-rmse:1.439692 test-rmse:1.631002
[23] train-rmse:1.424415 test-rmse:1.624349
[24] train-rmse:1.410186 test-rmse:1.618293
[25] train-rmse:1.394763 test-rmse:1.611141
[26] train-rmse:1.382552 test-rmse:1.605512
[27] train-rmse:1.367447 test-rmse:1.597054
[28] train-rmse:1.356154 test-rmse:1.592005
[29] train-rmse:1.344793 test-rmse:1.585211
[30] train-rmse:1.333928 test-rmse:1.581348
[31] train-rmse:1.323969 test-rmse:1.576604
[32] train-rmse:1.312801 test-rmse:1.571761
[33] train-rmse:1.302098 test-rmse:1.566222
[34] train-rmse:1.293532 test-rmse:1.562836
[35] train-rmse:1.284871 test-rmse:1.558357
[36] train-rmse:1.277237 test-rmse:1.554014
[37] train-rmse:1.268967 test-rmse:1.552095
[38] train-rmse:1.259250 test-rmse:1.549548
[39] train-rmse:1.250854 test-rmse:1.545687
[40] train-rmse:1.241821 test-rmse:1.544884
[41] train-rmse:1.234559 test-rmse:1.544940
[42] train-rmse:1.229263 test-rmse:1.541752
[43] train-rmse:1.220873 test-rmse:1.539150
[44] train-rmse:1.213864 test-rmse:1.537163
[45] train-rmse:1.205962 test-rmse:1.532867
[46] train-rmse:1.200961 test-rmse:1.529658
[47] train-rmse:1.193637 test-rmse:1.528302
[48] train-rmse:1.187168 test-rmse:1.527083
[49] train-rmse:1.180090 test-rmse:1.525951
[50] train-rmse:1.174782 test-rmse:1.523649
[51] train-rmse:1.170238 test-rmse:1.521110
[52] train-rmse:1.162460 test-rmse:1.519503
[53] train-rmse:1.157345 test-rmse:1.516278
[54] train-rmse:1.153607 test-rmse:1.515619
[55] train-rmse:1.150669 test-rmse:1.514745
[56] train-rmse:1.146366 test-rmse:1.513725
[57] train-rmse:1.141967 test-rmse:1.512760
[58] train-rmse:1.134378 test-rmse:1.510593
[59] train-rmse:1.128573 test-rmse:1.510020
[60] train-rmse:1.124535 test-rmse:1.510681
[61] train-rmse:1.119343 test-rmse:1.510245
[62] train-rmse:1.113330 test-rmse:1.508884
[63] train-rmse:1.105618 test-rmse:1.506367
[64] train-rmse:1.098194 test-rmse:1.506347
[65] train-rmse:1.094779 test-rmse:1.504716
[66] train-rmse:1.089849 test-rmse:1.503711
[67] train-rmse:1.082857 test-rmse:1.503017
[68] train-rmse:1.077175 test-rmse:1.501475
[69] train-rmse:1.072366 test-rmse:1.501128
[70] train-rmse:1.068529 test-rmse:1.500863
[71] train-rmse:1.063690 test-rmse:1.499950
[72] train-rmse:1.058948 test-rmse:1.499642
[73] train-rmse:1.051090 test-rmse:1.499222
[74] train-rmse:1.047310 test-rmse:1.499594
[75] train-rmse:1.044965 test-rmse:1.499160
[76] train-rmse:1.041559 test-rmse:1.497342
[77] train-rmse:1.037255 test-rmse:1.495467
[78] train-rmse:1.033923 test-rmse:1.495698
[79] train-rmse:1.027759 test-rmse:1.494544
[80] train-rmse:1.024731 test-rmse:1.494202
[81] train-rmse:1.021141 test-rmse:1.493573
[82] train-rmse:1.016536 test-rmse:1.491372
[83] train-rmse:1.010629 test-rmse:1.489853
[84] train-rmse:1.003788 test-rmse:1.489225
[85] train-rmse:1.001039 test-rmse:1.488963
[86] train-rmse:0.997133 test-rmse:1.487693
[87] train-rmse:0.993504 test-rmse:1.488640
[88] train-rmse:0.988923 test-rmse:1.487516
[89] train-rmse:0.985110 test-rmse:1.487273
[90] train-rmse:0.983987 test-rmse:1.486991
[91] train-rmse:0.981657 test-rmse:1.487075
[92] train-rmse:0.978000 test-rmse:1.485991
[93] train-rmse:0.975002 test-rmse:1.485134
[94] train-rmse:0.973392 test-rmse:1.484504
[95] train-rmse:0.970070 test-rmse:1.483558
[96] train-rmse:0.965876 test-rmse:1.482382
[97] train-rmse:0.962100 test-rmse:1.482339
[98] train-rmse:0.960346 test-rmse:1.481752
[99] train-rmse:0.957531 test-rmse:1.481287
[100] train-rmse:0.956131 test-rmse:1.480226
[101] train-rmse:0.954455 test-rmse:1.480473
[102] train-rmse:0.952365 test-rmse:1.480075
[103] train-rmse:0.949974 test-rmse:1.479913
[104] train-rmse:0.945196 test-rmse:1.478006
[105] train-rmse:0.939306 test-rmse:1.476034
[106] train-rmse:0.937217 test-rmse:1.475721
[107] train-rmse:0.934274 test-rmse:1.476059
[108] train-rmse:0.931232 test-rmse:1.476317
[109] train-rmse:0.929220 test-rmse:1.476118
[110] train-rmse:0.926741 test-rmse:1.475624
[111] train-rmse:0.924730 test-rmse:1.475058
[112] train-rmse:0.923732 test-rmse:1.474882
[113] train-rmse:0.921959 test-rmse:1.474408
[114] train-rmse:0.919593 test-rmse:1.473347
[115] train-rmse:0.917083 test-rmse:1.472874
[116] train-rmse:0.916204 test-rmse:1.472989
[117] train-rmse:0.914499 test-rmse:1.473238
[118] train-rmse:0.911374 test-rmse:1.472745
[119] train-rmse:0.908271 test-rmse:1.472332
[120] train-rmse:0.906324 test-rmse:1.471277
[121] train-rmse:0.904774 test-rmse:1.470961
[122] train-rmse:0.903580 test-rmse:1.470938
[123] train-rmse:0.902415 test-rmse:1.470566
[124] train-rmse:0.897758 test-rmse:1.469916
[125] train-rmse:0.897366 test-rmse:1.469692
[126] train-rmse:0.895669 test-rmse:1.469767
[127] train-rmse:0.894738 test-rmse:1.469355
[128] train-rmse:0.890474 test-rmse:1.468058
[129] train-rmse:0.887864 test-rmse:1.467617
[130] train-rmse:0.885567 test-rmse:1.467989
[131] train-rmse:0.881665 test-rmse:1.467039
[132] train-rmse:0.880868 test-rmse:1.466528
[133] train-rmse:0.879257 test-rmse:1.465947
[134] train-rmse:0.878440 test-rmse:1.465562
[135] train-rmse:0.878440 test-rmse:1.465563
[136] train-rmse:0.877177 test-rmse:1.465914
[137] train-rmse:0.875980 test-rmse:1.465894
[138] train-rmse:0.872677 test-rmse:1.465653
[139] train-rmse:0.872336 test-rmse:1.465572
[140] train-rmse:0.871586 test-rmse:1.465114
[141] train-rmse:0.871274 test-rmse:1.464773
[142] train-rmse:0.870926 test-rmse:1.464533
[143] train-rmse:0.869452 test-rmse:1.464345
[144] train-rmse:0.869119 test-rmse:1.464194
[145] train-rmse:0.867843 test-rmse:1.464281
[146] train-rmse:0.866683 test-rmse:1.463814
[147] train-rmse:0.864072 test-rmse:1.462787
[148] train-rmse:0.863431 test-rmse:1.462312
[149] train-rmse:0.861678 test-rmse:1.461565
[150] train-rmse:0.859769 test-rmse:1.461018
[151] train-rmse:0.857726 test-rmse:1.461043
[152] train-rmse:0.855316 test-rmse:1.460825
[153] train-rmse:0.855078 test-rmse:1.460858
[154] train-rmse:0.854531 test-rmse:1.460649
[155] train-rmse:0.854371 test-rmse:1.460506
[156] train-rmse:0.854054 test-rmse:1.460534
[157] train-rmse:0.853043 test-rmse:1.460167
[158] train-rmse:0.852387 test-rmse:1.459848
[159] train-rmse:0.850763 test-rmse:1.459622
[160] train-rmse:0.849740 test-rmse:1.459258
[161] train-rmse:0.847697 test-rmse:1.458956
[162] train-rmse:0.846688 test-rmse:1.459188
[163] train-rmse:0.846688 test-rmse:1.459189
[164] train-rmse:0.845541 test-rmse:1.459160
[165] train-rmse:0.844361 test-rmse:1.459259
[166] train-rmse:0.843876 test-rmse:1.459238
[167] train-rmse:0.843012 test-rmse:1.459397
[168] train-rmse:0.842695 test-rmse:1.459068
[169] train-rmse:0.841263 test-rmse:1.459216
[170] train-rmse:0.838951 test-rmse:1.458788
[171] train-rmse:0.837360 test-rmse:1.458522
[172] train-rmse:0.836775 test-rmse:1.458262
[173] train-rmse:0.834044 test-rmse:1.457664
[174] train-rmse:0.831470 test-rmse:1.457459
[175] train-rmse:0.830516 test-rmse:1.457276
[176] train-rmse:0.830015 test-rmse:1.456714
[177] train-rmse:0.828177 test-rmse:1.456465
[178] train-rmse:0.827930 test-rmse:1.456249
[179] train-rmse:0.826117 test-rmse:1.456285
[180] train-rmse:0.825291 test-rmse:1.456193
[181] train-rmse:0.825022 test-rmse:1.456237
[182] train-rmse:0.822976 test-rmse:1.455990
[183] train-rmse:0.819550 test-rmse:1.455594
[184] train-rmse:0.819134 test-rmse:1.455544
[185] train-rmse:0.818305 test-rmse:1.455587
[186] train-rmse:0.817262 test-rmse:1.456044
[187] train-rmse:0.814060 test-rmse:1.455725
[188] train-rmse:0.813500 test-rmse:1.455695
[189] train-rmse:0.813289 test-rmse:1.455571
[190] train-rmse:0.812503 test-rmse:1.455335
[191] train-rmse:0.811768 test-rmse:1.455235
[192] train-rmse:0.811557 test-rmse:1.455273
[193] train-rmse:0.810237 test-rmse:1.455007
[194] train-rmse:0.809739 test-rmse:1.454810
[195] train-rmse:0.808779 test-rmse:1.454504
[196] train-rmse:0.807546 test-rmse:1.454380
[197] train-rmse:0.805887 test-rmse:1.453924
[198] train-rmse:0.805232 test-rmse:1.453870
[199] train-rmse:0.804846 test-rmse:1.453646
[200] train-rmse:0.804846 test-rmse:1.453647
[201] train-rmse:0.804651 test-rmse:1.453469
[202] train-rmse:0.801963 test-rmse:1.452879
[203] train-rmse:0.800664 test-rmse:1.452671
[204] train-rmse:0.800308 test-rmse:1.452808
[205] train-rmse:0.798293 test-rmse:1.452213
[206] train-rmse:0.798136 test-rmse:1.452211
[207] train-rmse:0.797508 test-rmse:1.452808
[208] train-rmse:0.797301 test-rmse:1.452752
[209] train-rmse:0.796485 test-rmse:1.452395
[210] train-rmse:0.794158 test-rmse:1.451964
[211] train-rmse:0.793302 test-rmse:1.451972
[212] train-rmse:0.793079 test-rmse:1.451920
[213] train-rmse:0.792032 test-rmse:1.451697
[214] train-rmse:0.791850 test-rmse:1.452022
[215] train-rmse:0.791764 test-rmse:1.451928
[216] train-rmse:0.791686 test-rmse:1.451842
[217] train-rmse:0.791261 test-rmse:1.451764
[218] train-rmse:0.790861 test-rmse:1.451794
[219] train-rmse:0.789565 test-rmse:1.451629
[220] train-rmse:0.788573 test-rmse:1.451411
[221] train-rmse:0.788199 test-rmse:1.451300
[222] train-rmse:0.787727 test-rmse:1.451105
[223] train-rmse:0.787237 test-rmse:1.451147
[224] train-rmse:0.786367 test-rmse:1.451216
[225] train-rmse:0.785934 test-rmse:1.451206
[226] train-rmse:0.785084 test-rmse:1.451625
[227] train-rmse:0.784864 test-rmse:1.451209
[228] train-rmse:0.784685 test-rmse:1.451199
[229] train-rmse:0.784305 test-rmse:1.451162
[230] train-rmse:0.783932 test-rmse:1.450840
[231] train-rmse:0.781811 test-rmse:1.450500
[232] train-rmse:0.780730 test-rmse:1.450068
[233] train-rmse:0.778983 test-rmse:1.450461
[234] train-rmse:0.778983 test-rmse:1.450461
[235] train-rmse:0.778498 test-rmse:1.450519
[236] train-rmse:0.777903 test-rmse:1.450518
[237] train-rmse:0.775397 test-rmse:1.450578
[238] train-rmse:0.775397 test-rmse:1.450577
[239] train-rmse:0.775293 test-rmse:1.450432
[240] train-rmse:0.773335 test-rmse:1.450201
[241] train-rmse:0.773134 test-rmse:1.450122
[242] train-rmse:0.772223 test-rmse:1.449900
[243] train-rmse:0.770609 test-rmse:1.449612
[244] train-rmse:0.768555 test-rmse:1.448916
[245] train-rmse:0.768218 test-rmse:1.448935
[246] train-rmse:0.767148 test-rmse:1.449050
[247] train-rmse:0.767148 test-rmse:1.449049
[248] train-rmse:0.766173 test-rmse:1.448942
[249] train-rmse:0.765999 test-rmse:1.448702
[250] train-rmse:0.765615 test-rmse:1.448702
[251] train-rmse:0.765615 test-rmse:1.448703
[252] train-rmse:0.765615 test-rmse:1.448704
[253] train-rmse:0.765336 test-rmse:1.448703
[254] train-rmse:0.765073 test-rmse:1.448440
[255] train-rmse:0.764841 test-rmse:1.448256
[256] train-rmse:0.764679 test-rmse:1.448408
[257] train-rmse:0.764679 test-rmse:1.448408
[258] train-rmse:0.764249 test-rmse:1.448374
[259] train-rmse:0.764249 test-rmse:1.448374
[260] train-rmse:0.763734 test-rmse:1.448308
[261] train-rmse:0.763127 test-rmse:1.447987
[262] train-rmse:0.762913 test-rmse:1.448024
[263] train-rmse:0.762432 test-rmse:1.447797
[264] train-rmse:0.761881 test-rmse:1.447877
[265] train-rmse:0.761493 test-rmse:1.447667
[266] train-rmse:0.760016 test-rmse:1.447224
[267] train-rmse:0.760016 test-rmse:1.447223
[268] train-rmse:0.758637 test-rmse:1.446934
[269] train-rmse:0.758461 test-rmse:1.446729
[270] train-rmse:0.757796 test-rmse:1.446697
[271] train-rmse:0.757504 test-rmse:1.446520
[272] train-rmse:0.755662 test-rmse:1.446263
[273] train-rmse:0.753992 test-rmse:1.445813
[274] train-rmse:0.753258 test-rmse:1.445777
[275] train-rmse:0.751414 test-rmse:1.445279
[276] train-rmse:0.751414 test-rmse:1.445279
[277] train-rmse:0.751091 test-rmse:1.445330
[278] train-rmse:0.750699 test-rmse:1.445119
[279] train-rmse:0.750470 test-rmse:1.445060
[280] train-rmse:0.750298 test-rmse:1.444934
[281] train-rmse:0.749454 test-rmse:1.444948
[282] train-rmse:0.748608 test-rmse:1.445291
[283] train-rmse:0.747309 test-rmse:1.444728
[284] train-rmse:0.746919 test-rmse:1.444627
[285] train-rmse:0.746919 test-rmse:1.444626
[286] train-rmse:0.746919 test-rmse:1.444628
[287] train-rmse:0.744860 test-rmse:1.444149
[288] train-rmse:0.742424 test-rmse:1.443623
[289] train-rmse:0.741312 test-rmse:1.443528
[290] train-rmse:0.739679 test-rmse:1.443200
[291] train-rmse:0.739679 test-rmse:1.443200
[292] train-rmse:0.739280 test-rmse:1.442871
[293] train-rmse:0.738191 test-rmse:1.443037
[294] train-rmse:0.737708 test-rmse:1.443045
[295] train-rmse:0.737351 test-rmse:1.442922
[296] train-rmse:0.736442 test-rmse:1.442800
[297] train-rmse:0.736244 test-rmse:1.442868
[298] train-rmse:0.736243 test-rmse:1.442869
[299] train-rmse:0.735299 test-rmse:1.442860
[300] train-rmse:0.734449 test-rmse:1.442569
[301] train-rmse:0.734356 test-rmse:1.442586
[302] train-rmse:0.734356 test-rmse:1.442586
[303] train-rmse:0.733719 test-rmse:1.442852
[304] train-rmse:0.733482 test-rmse:1.442839
[305] train-rmse:0.733016 test-rmse:1.443598
[306] train-rmse:0.733016 test-rmse:1.443599
[307] train-rmse:0.733016 test-rmse:1.443600
[308] train-rmse:0.732096 test-rmse:1.443293
[309] train-rmse:0.731808 test-rmse:1.443311
[310] train-rmse:0.731808 test-rmse:1.443312
[311] train-rmse:0.731364 test-rmse:1.443298
[312] train-rmse:0.731364 test-rmse:1.443298
[313] train-rmse:0.731364 test-rmse:1.443298
[314] train-rmse:0.731139 test-rmse:1.443332
[315] train-rmse:0.731139 test-rmse:1.443335
[316] train-rmse:0.731139 test-rmse:1.443333
[317] train-rmse:0.730986 test-rmse:1.443330
[318] train-rmse:0.730987 test-rmse:1.443329
[319] train-rmse:0.730005 test-rmse:1.443183
[320] train-rmse:0.730004 test-rmse:1.443184
Stopping. Best iteration:
[300] train-rmse:0.734449 test-rmse:1.442569
# Create directory if it doesn't exist
if (!dir.exists("../data/models")) {
dir.create("../data/models", recursive = TRUE)
}
# Save the final XGBoost model
saveRDS(final_model, file = "../data/models/xgb_model.rds")
# Save the best parameters
saveRDS(best_params, file = "../data/models/xgb_best_params.rds")
cat("Model and parameters saved to ../data/models/")
Model and parameters saved to ../data/models/
# Load the final XGBoost model
final_model <- readRDS("../data/models/xgb_model.rds")
# Load the best parameters
best_params <- readRDS("../data/models/xgb_best_params.rds")
library(Metrics)
library(ggplot2)
library(dplyr)
set.seed(2025)
# Predict on test set
preds <- predict(final_model, as.matrix(X_test))
# --- Metrics ---
rmse <- sqrt(mean((y_test - preds)^2))
mae <- mean(abs(y_test - preds))
mape <- mean(abs((y_test - preds) / y_test)) * 100
r2 <- 1 - (sum((y_test - preds)^2) / sum((y_test - mean(y_test))^2))
cat("Model Evaluation Metrics:\n")
Model Evaluation Metrics:
cat(" RMSE:", rmse, "\n")
RMSE: 1.442569
cat(" MAE :", mae, "\n")
MAE : 0.7737608
cat(" MAPE:", mape, "%\n")
MAPE: Inf %
cat(" R² :", r2, "\n\n")
R² : 0.3582584
# --- Residuals ---
residuals <- y_test - preds
residual_df <- data.frame(
actual = y_test,
predicted = preds,
residuals = residuals
)
# --- Plot: Predicted vs Actual ---
p1 <- residual_df %>%
ggplot(aes(x = actual, y = predicted)) +
geom_point(alpha = 0.5) +
geom_abline(slope = 1, intercept = 0, color = "red") +
theme_minimal() +
labs(title = "Predicted vs Actual Crash Rates",
x = "Actual",
y = "Predicted")
# --- Plot: Residuals vs Predicted ---
p2 <- residual_df %>%
ggplot(aes(x = predicted, y = residuals)) +
geom_point(alpha = 0.5, color = "blue") +
geom_hline(yintercept = 0, linetype = "dashed", color = "red") +
theme_minimal() +
labs(title = "Residuals vs Predicted",
x = "Predicted",
y = "Residuals")
# --- Plot: Residual Density ---
# Residual Histogram
p3 <- ggplot(residual_df, aes(x = residuals)) +
geom_histogram(binwidth = 0.2, fill = "steelblue", color = "white") +
geom_density(color = "red") +
theme_minimal() +
labs(title = "Residual Distribution", x = "Residuals", y = "Count")
# Print plots
print(p1)
print(p2)
print(p3)
ggsave("../report/plots/predicted_vs_actual_values_plot.png", p1, width = 10, height = 6, dpi = 300)
ggsave("../report/plots/resisuals_vs_predicted_values_plot.png", p2, width = 10, height = 6, dpi = 300)
ggsave("../report/plots/residual_density_plot.png", p3, width = 10, height = 6, dpi = 300)
# Compute SHAP values
shap_values <- shap.values(xgb_model = final_model, X_train = as.matrix(X_train))
shap_long <- shap.prep(shap_contrib = shap_values$shap_score, X_train = as.matrix(X_train))
# SHAP summary plot
print(shap.plot.summary(shap_long))
if (!dir.exists("../report/plots")) {
dir.create("../report/plots")
}
shap <- shap.plot.summary(shap_long)
ggsave("../report/plots/shap_summary_plot.png", shap, width = 10, height = 6, dpi = 300)
xgb.plot.tree(model = final_model, trees = 0)
xgb.plot.tree(model = final_model, trees = 1)
xgb.plot.tree(model = final_model, trees = 2)
xgb.plot.multi.trees(model = final_model)
# ============================================================
# Additional Model Diagnostics and Deeper Analysis
# ============================================================
library(ggplot2)
library(dplyr)
library(pdp) # For Partial Dependence Plots
library(DALEX) # For model explainability
library(ggthemes)
library(sf)
# ---------------------------
# 1. SHAP Dependence and Interaction Plots
# ---------------------------
message("\nGenerating SHAP dependence and interaction plots...")
# Assuming shap_values and shap_long are already computed
# (if not, recompute them using iml or SHAPforxgboost packages)
# Top feature by SHAP importance
top_feature <- shap_long %>%
as_tibble() %>%
count(variable, wt = abs(value), sort = TRUE) %>%
dplyr::slice(1) %>%
pull(variable)
# Dependence plot for top feature
shap.plot.dependence(data_long = shap_long, x = top_feature, color_feature = top_feature)
# Interaction values
shap_interaction_values <- predict(
final_model,
as.matrix(X_train),
predinteraction = TRUE
)
# shap_interaction_values will be a 3D array: [n_samples, n_features, n_features]
interactions <- dim(shap_interaction_values)
# ---------------------------
# 3. Partial Dependence Plots (PDP)
# ---------------------------
message("\nGenerating Partial Dependence Plots...")
top_features <- shap_long %>%
count(variable, wt = abs(value), sort = TRUE) %>%
dplyr::slice(1:10) %>%
pull(variable)
# Ensure the output directory exists
dir.create("../report/plots/pdp", recursive = TRUE, showWarnings = FALSE)
for (f in top_features) {
pd <- partial(final_model, pred.var = f, train = as.matrix(X_train), grid.resolution = 30)
p <- plot(pd, main = paste("Partial Dependence of", f))
# Save as PNG
ggsave(
filename = paste0("../report/plots/pdp/pdp_", f, ".png"),
plot = p,
width = 8,
height = 6,
dpi = 300
)
}